{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
import math
import typing as T

import numpy

{# If this is a pose type, import the rotation type (and any other classes listed in imported_classes) it depends on. #}
{% if imported_classes is defined %}
    {% for imported_cls in imported_classes %}
from .{{ camelcase_to_snakecase(imported_cls.__name__ )}} import {{ imported_cls.__name__ }}
    {% endfor %}
{% endif -%}

# isort: split
from .ops import {{ camelcase_to_snakecase(cls.__name__) }} as ops


class {{ cls.__name__ }}(object):
    {% if doc %}
    """
    Autogenerated Python implementation of {{ cls }}.

    {% for line in doc.split('\n') %}
    {{ line.rstrip() }}
    {% endfor %}
    """
    {% endif %}

    __slots__ = ['data']

    def __repr__(self):
        # type: () -> str
        return '<{} {}>'.format(self.__class__.__name__, self.data)

    {% set custom_template_name = "custom_methods/{}.py.jinja".format(cls.__name__.lower()) %}
    # --------------------------------------------------------------------------
    # Handwritten methods included from "{{ custom_template_name }}"
    # --------------------------------------------------------------------------

    {% include custom_template_name %}

    {% if custom_generated_methods %}
    {%- import "../util/util.jinja" as util with context -%}
    # --------------------------------------------------------------------------
    # Custom generated methods
    # --------------------------------------------------------------------------

    {% endif %}

    {% for spec in custom_generated_methods %}
        {% set available_classes = [cls] %}
        {% if imported_classes is defined %}
            {% set available_classes = available_classes + imported_classes %}
        {% endif %}
    {{ util.function_declaration(spec, is_method=True, available_classes=available_classes) | indent(4)}}
        {{ util.print_docstring(spec.docstring) | indent(8) }}

        {{ util.expr_code(spec, available_classes=available_classes) | indent(4) }}

    {% endfor %}

    # --------------------------------------------------------------------------
    # StorageOps concept
    # --------------------------------------------------------------------------

    @staticmethod
    def storage_dim():
        # type: () -> int
        return {{ ops.StorageOps.storage_dim(cls) }}

    def to_storage(self):
        # type: () -> T.List[float]
        return list(self.data)

    @classmethod
    def from_storage(cls, vec):
        # type: (T.Sequence[float]) -> {{ cls.__name__ }}
        instance = cls.__new__(cls)

        if isinstance(vec, list):
            instance.data = vec
        else:
            instance.data = list(vec)

        if len(vec) != cls.storage_dim():
            raise ValueError(
                "{} has storage dim {}, got {}.".format(cls.__name__, cls.storage_dim(), len(vec))
            )

        return instance

    # --------------------------------------------------------------------------
    # GroupOps concept
    # --------------------------------------------------------------------------

    @classmethod
    def identity(cls):
        # type: () -> {{ cls.__name__ }}
        return ops.GroupOps.identity()

    def inverse(self):
    # type: () -> {{ cls.__name__ }}
        return ops.GroupOps.inverse(self)

    def compose(self, b):
        # type: ({{ cls.__name__ }}) -> {{ cls.__name__ }}
        return ops.GroupOps.compose(self, b)

    def between(self, b):
        # type: ({{ cls.__name__ }}) -> {{ cls.__name__ }}
        return ops.GroupOps.between(self, b)

    {% if is_lie_group %}
    # --------------------------------------------------------------------------
    # LieGroupOps concept
    # --------------------------------------------------------------------------

    @staticmethod
    def tangent_dim():
        # type: () -> int
        return {{ ops.LieGroupOps.tangent_dim(cls) }}

    @classmethod
    def from_tangent(cls, vec, epsilon=1e-8):
        # type: (numpy.ndarray, float) -> {{ cls.__name__ }}
        if len(vec) != cls.tangent_dim():
            raise ValueError(
                "Vector dimension ({}) not equal to tangent space dimension ({}).".format(
                    len(vec), cls.tangent_dim()
                )
            )
        return ops.LieGroupOps.from_tangent(vec, epsilon)

    def to_tangent(self, epsilon=1e-8):
        # type: (float) -> numpy.ndarray
        return ops.LieGroupOps.to_tangent(self, epsilon)

    def retract(self, vec, epsilon=1e-8):
        # type: (numpy.ndarray, float) -> {{ cls.__name__ }}
        if len(vec) != self.tangent_dim():
            raise ValueError(
                "Vector dimension ({}) not equal to tangent space dimension ({}).".format(
                    len(vec), self.tangent_dim()
                )
            )
        return ops.LieGroupOps.retract(self, vec, epsilon)

    def local_coordinates(self, b, epsilon=1e-8):
        # type: ({{ cls.__name__ }}, float) -> numpy.ndarray
        return ops.LieGroupOps.local_coordinates(self, b, epsilon)

    def interpolate(self, b, alpha, epsilon=1e-8):
        # type: ({{ cls.__name__ }}, float, float) -> {{ cls.__name__ }}
        return ops.LieGroupOps.interpolate(self, b, alpha, epsilon)
    {% endif %}

    # --------------------------------------------------------------------------
    # General Helpers
    # --------------------------------------------------------------------------
    def __eq__(self, other):
        # type: (T.Any) -> bool
        if isinstance(other, {{ cls.__name__ }}):
            return self.data == other.data
        else:
            return False

    @T.overload
    def __mul__(self, other):  # pragma: no cover
        # type: ({{ cls.__name__ }}) -> {{ cls.__name__ }}
        pass

    @T.overload
    def __mul__(self, other):  # pragma: no cover
        # type: (numpy.ndarray) -> numpy.ndarray
        pass

    def __mul__(self, other):
        # type: (T.Union[{{ cls.__name__ }}, numpy.ndarray]) -> T.Union[{{ cls.__name__ }}, numpy.ndarray]
        if isinstance(other, {{ cls.__name__ }}):
            return self.compose(other)
        elif isinstance(other, numpy.ndarray) and hasattr(self, "compose_with_point"):
            return getattr(self, "compose_with_point")(other).reshape(other.shape)
        else:
            raise NotImplementedError('Cannot compose {} with {}.'.format(type(self), type(other)))
